import os
import gc
import random
import math
import glob
import os
import copy
import base64
import ujson
import itertools
import pickle
import difflib

import numpy as np
import pandas as pd
from tqdm import tqdm
import seaborn as sns
from sympy import preview

import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList

import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.patheffects as path_effects
import matplotlib.lines as mlines
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
import imageio.v3 as iio
from io import BytesIO
from PIL import Image as PILImage

from scipy.optimize import linear_sum_assignment
from google import genai
from google.genai import types

import ipywidgets as widgets
from IPython.display import HTML, Image as IPImage, display

# Render all matplotlib text through LaTeX in Computer Modern. The preamble
# loads amsmath, soul and xcolor, which the highlighting snippets built later
# in this module (\hl, \sethlcolor, \textcolor-style names) rely on.
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Computer Modern Roman"]
plt.rcParams["text.latex.preamble"] = " ".join(
    ["\\usepackage{amsmath}", "\\usepackage{soul}", "\\usepackage{xcolor}"])

#==== 01 ====#
def steps_reduce(step_list, data, shared_keys=("model", "tokenizer")):
    """Thread ``data`` through each step of ``step_list`` in order.

    Every step receives a deep copy of the previous result, so a step may
    mutate its input without touching earlier results. Values stored under
    ``shared_keys`` (by default the loaded model and tokenizer) are passed by
    reference instead: deep-copying a loaded language model on every step
    duplicates its weights in memory and is very slow. The steps in this
    module only read the model/tokenizer (apart from ``model.eval()``).

    Args:
        step_list: iterable of callables ``step(data) -> data``.
        data: pipeline state, normally a dict.
        shared_keys: dict keys whose values are shared rather than copied;
            pass ``()`` to deep-copy everything.

    Returns:
        The result of the last step (``data`` itself if ``step_list`` is empty).
    """
    for step in step_list:
        # Seeding the deepcopy memo with ``id(obj) -> obj`` makes deepcopy
        # reuse those objects instead of cloning them.
        memo = {}
        if isinstance(data, dict):
            for key in shared_keys:
                if key in data:
                    memo[id(data[key])] = data[key]
        data = step(copy.deepcopy(data, memo))
    return data
def step_prefixes_load(data):
    """Load the prompt prefixes from the JSON file named in the config.

    Reads ``config["data.file_name"]`` with ujson and keeps the slice
    selected by ``config.get("data.lines", (None, None))`` = ``(from, to)``;
    either bound may be ``None``. Sets ``data["prefixes_text"]``.
    """
    config = data["config"]
    # JSON is UTF-8; don't depend on the platform's locale encoding.
    with open(config["data.file_name"], 'r', encoding="utf-8") as f:
        lines = ujson.load(f)
    line_from, line_to = config.get("data.lines", (None, None))
    # Truncate first, then drop the head, so from/to are both indices into
    # the original list.
    if line_to is not None:
        lines = lines[:line_to]
    if line_from is not None:
        lines = lines[line_from:]
    data["prefixes_text"] = lines
    return data

#==== 02 ====#
# Module-wide cache of the loaded model: repeated pipeline runs that call
# step_model_load reuse this instance instead of reloading the weights.
global_model = None
def step_model_load(data):
    """Pick a torch device, load (or reuse) the causal LM and load its tokenizer.

    Reads ``verbose``, ``model.name``, ``tokenizer.padding_side`` and
    ``tokenizer.truncation_side`` from ``data["config"]``.

    Sets ``data["device"]`` ("cuda", "mps" or "cpu"), ``data["model"]`` and
    ``data["tokenizer"]``; returns ``data``.

    The model is loaded once into ``global_model``; later calls must request
    the same ``model.name`` (asserted). On CUDA it is loaded 4-bit NF4
    quantized via bitsandbytes; on MPS/CPU it is loaded unquantized.
    """
    global global_model
    config = data['config']
    verbose = config["verbose"]
    if torch.cuda.is_available(): data["device"] = "cuda"
    elif torch.backends.mps.is_available(): data["device"] = "mps"
    else: data["device"] = "cpu"
    if verbose: print(f'torch.device={data["device"]}')

    model_name = config["model.name"]
    # Presumably set to silence the HF tokenizers fork/parallelism warning.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    if global_model is None:
        if data["device"] == "cuda":
            from transformers import BitsAndBytesConfig
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="auto",
                attn_implementation="sdpa",
                quantization_config=bnb_config,
                low_cpu_mem_usage=True,
                device_map="auto",
            )
        else:
            # mps and cpu: unquantized load, then move to the device. Without
            # this branch a CPU-only machine never assigned `model`.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="auto",
                attn_implementation="sdpa",
                quantization_config=None,
                low_cpu_mem_usage=True,
            ).to(data["device"])
            if data["device"] == "mps":
                model.generation_config.use_cache = True
                model.generation_config.cache_implementation = "static"
        global_model = model
    assert global_model.config._name_or_path == model_name, (
        f"global_model holds {global_model.config._name_or_path!r}, "
        f"but {model_name!r} was requested")
    data["model"] = global_model
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        padding_side=config["tokenizer.padding_side"],
        truncation_side=config["tokenizer.truncation_side"],
    )
    if tokenizer.pad_token is None:
        # NOTE(review): id 2 is vocabulary-specific (often `</s>`); confirm
        # it is a sensible pad token for other models.
        tokenizer.pad_token = tokenizer.decode([2])
    if verbose: print(f"tokenizer.pad_token={tokenizer.pad_token}")
    data["tokenizer"] = tokenizer
    return data
def step_tokenizer_load(data):
    """Pick a torch device and load only the tokenizer (no model weights).

    Lightweight counterpart of ``step_model_load``. Reads ``verbose``,
    ``model.name``, ``tokenizer.padding_side`` and
    ``tokenizer.truncation_side`` from ``data["config"]``; sets
    ``data["device"]`` and ``data["tokenizer"]``.
    """
    cfg = data['config']
    is_verbose = cfg["verbose"]
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    data["device"] = device
    if is_verbose:
        print(f'torch.device={data["device"]}')
    name = cfg["model.name"]
    tok = AutoTokenizer.from_pretrained(
        name,
        padding_side=cfg["tokenizer.padding_side"],
        truncation_side=cfg["tokenizer.truncation_side"],
    )
    # Tokenizers without a pad token fall back to the token with id 2.
    if tok.pad_token is None:
        tok.pad_token = tok.decode([2])
    if is_verbose:
        print(f"tokenizer.pad_token={tok.pad_token}")
    data["tokenizer"] = tok
    return data

def cache_clear(data):
    if data["device"] == "mps":
        torch.mps.empty_cache()
    elif data["device"] == "cuda":
        torch.cuda.empty_cache()
    torch._dynamo.reset()
    gc.collect()
def step_random_reset(data):
    """Clear caches, then reseed the Python, NumPy and torch RNGs.

    The seed is ``config["random.seed"]`` (42 when absent), so runs that start
    with this step draw the same random numbers.
    """
    cache_clear(data)
    seed = data['config'].get("random.seed", 42)
    for reseed in (random.seed, np.random.seed, torch.manual_seed):
        reseed(seed)
    return data

def step_tokenize(data):
    """Tokenize ``data["prefixes_text"]`` into a padded batch on ``data["device"]``.

    Tokenizer options come from the config keys ``tokenizer.padding``
    (default True), ``tokenizer.add_special_tokens`` (True),
    ``tokenizer.max_length`` (None) and ``tokenizer.truncation`` (False).

    Sets:
        prefixes_tokd: the tokenizer output (input_ids, attention_mask, ...).
        prefixes_tokd_len: padded length, i.e. ``input_ids.size(1)``.
        prefixes_tokd_len_list: per-row attention_mask sum (non-padding tokens).
        prefixes_text_actual: decoded prefixes, only when
            ``tokenizer.output_actual_prefixes_text`` is true.
    """
    cfg = data['config']
    is_verbose = cfg["verbose"]
    tok = data['tokenizer']
    encoded = tok(
        data["prefixes_text"],
        return_tensors="pt",
        padding=cfg.get("tokenizer.padding", True),
        add_special_tokens=cfg.get("tokenizer.add_special_tokens", True),
        max_length=cfg.get("tokenizer.max_length", None),
        truncation=cfg.get("tokenizer.truncation", False),
    ).to(data["device"])
    padded_len = encoded['input_ids'].size(1)
    real_lens = encoded['attention_mask'].sum(dim=-1).tolist()
    if is_verbose:
        print(f"prefix token counts: {real_lens} / {padded_len}")
    data["prefixes_tokd"] = encoded
    data["prefixes_tokd_len"] = padded_len
    data["prefixes_tokd_len_list"] = real_lens
    if cfg.get("tokenizer.output_actual_prefixes_text", False):
        data["prefixes_text_actual"] = tok.batch_decode(
            encoded['input_ids'], skip_special_tokens=True)
    return data

def step_generate(data, processors=None):
    """Generate exactly ``config["model.token_count"]`` new tokens per prefix.

    Results of a previous generation (and any ``log.*`` entries) are removed
    first. Every logits processor gets ``p.data = data`` so it can read the
    config and record state while ``model.generate`` runs under
    ``torch.no_grad()``.

    Args:
        data: pipeline state with ``config``, ``model``, ``prefixes_tokd``
            and ``prefixes_tokd_len``.
        processors: optional iterable of logits processors (default: none).

    Sets ``suffixes_tokd`` (only the newly generated ids) and, when the
    generate output contains ``scores`` (``model.output_scores``),
    ``suffixes_score`` with the per-step scores stacked along dim 1.
    """
    # None sentinel instead of a mutable `[]` default shared by all calls.
    processors = [] if processors is None else list(processors)
    for k in ['tokd', 'suffixes_tokd', 'suffixes_score']:
        data.pop(k, None)
    config = data['config']
    verbose = config.get("verbose", False)
    for k in ['suffixes_score_unmodified', 'suffixes_green']:
        data.pop(k, None)
    log_keys = [key for key in data if key.startswith("log.")]
    for key in log_keys:
        del data[key]
    for p in processors:
        p.data = data
    token_count = config["model.token_count"]
    with torch.no_grad():
        gen_output = data["model"].generate(
            **data["prefixes_tokd"],
            logits_processor=LogitsProcessorList(processors),
            # min == max pins the output length to token_count.
            min_new_tokens=token_count, max_new_tokens=token_count,
            do_sample=config.get("model.do_sample", False),
            num_beams=config.get("model.num_beams", 1),
            use_cache=config.get("model.use_cache", True),
            top_k=config.get("model.top_k", None),
            top_p=config.get("model.top_p", None),
            temperature=config.get("model.temperature", None),
            # Note: a 1.25 repetition penalty is applied unless configured.
            repetition_penalty=config.get("model.repetition_penalty", 1.25),
            output_scores=config.get("model.output_scores", False),
            return_dict_in_generate=True,
        )
    # Sequences are the (padded) prefix followed by the new tokens.
    data['suffixes_tokd'] = gen_output['sequences'][:, data['prefixes_tokd_len']:]
    if verbose: print(f"suffix tensor shape: {data['suffixes_tokd'].shape}")
    if 'scores' in gen_output:
        data['suffixes_score'] = torch.stack(gen_output['scores'], dim=1)
        if verbose: print(f"suffix score shape: {data['suffixes_score'].shape}")
    return data
def step_decode(data):
    """Decode generated suffixes and the (padded) prefixes back to text.

    Special tokens (e.g. padding) are skipped. Sets ``suffixes_text`` and
    ``prefixes_text_trimmed``.
    """
    decode = data['tokenizer'].batch_decode
    suffix_ids = data['suffixes_tokd']
    prefix_ids = data['prefixes_tokd']['input_ids']
    data['suffixes_text'] = decode(suffix_ids, skip_special_tokens=True)
    data['prefixes_text_trimmed'] = decode(prefix_ids, skip_special_tokens=True)
    return data

#==== 03 ====#
def fn_colors_fixed(tokd_prev, data):
    """Assign every vocabulary id a color that is the same at every step.

    A generator with a fixed seed draws one uniform number per token id and
    buckets it into ``len(config["watermark.colors"])`` colors, so the
    partition does not depend on the context. Returns an int32 tensor of
    shape ``(batch, vocab_size)`` whose rows are identical.
    """
    n_vocab = data["vocab_size"]
    n_colors = len(data["config"]["watermark.colors"])
    device = tokd_prev.device
    gen = torch.Generator(device)
    gen.manual_seed(15485863)
    uniform = torch.rand((n_vocab,), generator=gen, device=device)
    palette = torch.floor(uniform * n_colors).to(torch.int32)
    return palette.repeat(tokd_prev.size(0), 1)

def fn_colors_random_slow(tokd_prev, data):
    """Per-row color partition seeded by each row's most recent token.

    For each row the generator is reseeded with ``3779 * last_token_id`` and
    one uniform number per vocabulary id is bucketed into the configured
    number of colors. Loops over the batch in Python (hence "slow").
    Returns an int32 tensor of shape ``(batch, vocab_size)``.
    """
    n_vocab = data["vocab_size"]
    n_colors = len(data["config"]["watermark.colors"])
    device = tokd_prev.device
    gen = torch.Generator(device)

    def _row_colors(last_token):
        gen.manual_seed(3779 * last_token.item())
        uniform = torch.rand((n_vocab,), generator=gen, device=device)
        return torch.floor(uniform * n_colors)

    rows = [_row_colors(last) for last in tokd_prev[:, -1]]
    return torch.stack(rows, dim=0).int()

def fn_colors_popular_alternating(tokd_prev, data):
    """Color tokens by logit rank, shifted by the previous token id.

    Ranks the vocabulary by the last entry of
    ``data['suffixes_logits_unmodified']`` (rank 0 = highest logit), adds the
    id of each row's most recent token and reduces modulo the number of
    colors, so consecutive ranks get consecutive colors.

    NOTE(review): nothing in this module currently records
    ``suffixes_logits_unmodified`` (the recording call in
    ``Perturber.__call__`` is commented out); confirm the caller provides it.
    """
    n_colors = len(data["config"]["watermark.colors"])
    latest_logits = data['suffixes_logits_unmodified'][:, -1, :]
    order = torch.argsort(latest_logits, dim=-1, descending=True)
    rank = torch.argsort(order, dim=-1)  # inverse permutation: token id -> rank
    shift = tokd_prev[:, -1, None]
    return (rank + shift) % n_colors

def fn_pattern(tokd_prev, data):
    """Return the color index to boost at the current generation step.

    The step index is ``tokd_prev.size(1)`` (``Perturber`` passes only the
    generated suffix tokens), and the color cycles through
    ``config["watermark.color_pattern"]``.

    Returns an int64 tensor of shape ``(batch, 1)`` filled with that index.
    """
    pos = tokd_prev.size(1)
    wm_colors = data["config"]["watermark.colors"]
    wm_pattern = data["config"]["watermark.color_pattern"]
    # Pattern entries index into wm_colors, i.e. 0 .. len(wm_colors) - 1.
    # An entry equal to len(wm_colors) matches no token's color and would
    # silently disable boosting for that step.
    assert min(wm_pattern) >= 0 and max(wm_pattern) < len(wm_colors), \
        "watermark.color_pattern entries must index into watermark.colors"
    col_idx = wm_pattern[pos % len(wm_pattern)]
    return torch.full(
        (tokd_prev.size(0), 1), col_idx,
        device=tokd_prev.device)

def tensor_cat(data, key, new_row):
    """Append ``new_row`` to ``data[key]`` as a new column along dim 1.

    ``new_row`` is cloned and unsqueezed at dim 1. If ``key`` is absent, the
    column starts a new tensor; otherwise it is concatenated after the
    existing columns. Mutates ``data`` in place and returns None.
    """
    column = new_row.clone().unsqueeze(1)
    if key in data:
        data[key] = torch.cat((data[key], column), dim=1)
    else:
        data[key] = column
class Perturber(LogitsProcessor):
    """Logits processor that adds the watermark bias during generation.

    ``step_generate`` assigns the pipeline dict to ``self.data`` before
    generation. The settings are read from ``self.data["config"]``:

    * ``watermark.fn_colors(tokd_prev, data)`` gives the color of every token.
    * ``watermark.fn_pattern(tokd_prev, data)`` gives the color to boost this
      step; it is called with the generated suffix tokens only.
    * ``watermark.delta`` is the bias added to logits of tokens of that color.

    As a side effect it stores the vocabulary size in ``data["vocab_size"]``.
    """
    def __call__(self, tokd_prev, logits):
        state = self.data
        state["vocab_size"] = logits.shape[-1]
        # tensor_cat(self.data, 'suffixes_logits_unmodified', logits)
        cfg = state["config"]
        # Vocabulary partition for this step; (batch, vocab) for the
        # fn_colors_* helpers in this module.
        token_colors = cfg["watermark.fn_colors"](tokd_prev=tokd_prev, data=state)
        delta = cfg['watermark.delta']
        # Only the tokens after the prompt decide the position in the pattern.
        suffix_so_far = tokd_prev[:, state["prefixes_tokd_len"]:]
        boosted_color = cfg["watermark.fn_pattern"](tokd_prev=suffix_so_far, data=state)
        # In place: tokens whose color equals the boosted color gain +delta.
        logits += delta * (token_colors == boosted_color)
        # Recording is disabled. latex_text_color reads data['suffixes_colors']
        # and fn_colors_popular_alternating reads
        # data['suffixes_logits_unmodified'], so they need it re-enabled.
        # tensor_cat(self.data, 'suffixes_colors', colors)
        # tensor_cat(self.data, 'suffixes_logits', logits)
        return logits
def step_generate_watermarked(data):
    """Run ``step_generate`` with a ``Perturber`` applying the watermark bias.

    Drops recordings left over from a previous watermarked run first. With
    ``watermark.delta == 0`` the bias is zero, so this also produces
    unwatermarked generations through the same code path.
    """
    stale_keys = ('suffixes_logits_unmodified', 'suffixes_logits', 'suffixes_colors')
    for stale in stale_keys:
        data.pop(stale, None)
    return step_generate(data, processors=[Perturber()])

# LaTeX text-mode escapes. str.translate applies all of them in one pass, so
# the backslashes introduced by one escape are never escaped again.
_LATEX_TEXT_ESCAPES = str.maketrans({
    "\\": "\\textbackslash ",
    "&": "\\&",
    "$": "\\$",
    "{": "\\{",
    "}": "\\}",
    "%": "\\%",
    "#": "\\#",
    "_": "\\_",
    "^": "\\textasciicircum ",
    "~": "\\textasciitilde ",
})
def str_escape_for_latex(text):
    r"""Make a token string safe to typeset in LaTeX text mode.

    The SentencePiece word marker ``▁`` becomes a space. The text then goes
    through ``repr`` and ``unicode-escape``, so control and non-ASCII
    characters appear as visible escapes (e.g. ``é`` renders as
    ``\textbackslash xe9``). Finally every LaTeX-special character is escaped.
    Previously ``%``, ``#``, ``_``, ``^`` and ``~`` were left unescaped; ``%``
    in particular commented out the rest of the generated line.
    """
    visible = repr(text.replace("▁", " ")).encode('unicode-escape').decode('ascii')
    # [1:-1] drops the quote characters added by repr.
    return visible.translate(_LATEX_TEXT_ESCAPES)[1:-1]
def latex_text_color(datas, title, row_idx, height=128, subscript=False):
    r"""Render one batch row of several runs as color-highlighted LaTeX.

    For each run it writes a label ($\delta$, plus the color pattern when
    delta > 0) and a grey box holding the escaped prefix, followed by every
    generated token highlighted (``\hl``) in its color at that step.

    Args:
        datas: run states. Each needs ``config`` (``watermark.delta``,
            ``watermark.color_pattern``, ``watermark.colors``),
            ``prefixes_text``, ``suffixes_tokd``, ``suffixes_colors`` and
            ``tokenizer``.
        title: caption. When not None, the content is wrapped in ``figure*``.
        row_idx: batch row to render.
        height: pixels kept per run when cropping the rendered PNG.
        subscript: append each token's index as a math subscript.

    Returns:
        The LaTeX source. The rendered PNG is shown with IPython ``display``.

    NOTE(review): ``suffixes_colors`` is not recorded by ``Perturber`` in this
    module (its tensor_cat call is commented out); the caller must provide it.
    """
    latex = "\\begin{figure*}[htbp]" if title is not None else ""
    for data in datas:
        config = data['config']
        delta = config['watermark.delta']
        label = f"$\\delta$={delta}"
        if delta > 0:
            pattern_text = ''.join(str(i) for i in config['watermark.color_pattern'])
            label += f", pattern={pattern_text}"
        latex += f"\\fcolorbox{{black!0}}{{black!0}}{{\n{label}}}\n\n"
        latex += "\\fcolorbox{black!0}{black!15}{\n"
        latex += "\\begin{minipage}{\\linewidth}\n"
        latex += f"\\ldots {str_escape_for_latex(data['prefixes_text'][row_idx])}"
        token_steps = zip(data["suffixes_tokd"][row_idx], data["suffixes_colors"][row_idx])
        for tokd_idx, (tokd, colors) in enumerate(token_steps):
            text = str_escape_for_latex(data["tokenizer"].convert_ids_to_tokens([tokd])[0])
            if subscript: text += f"$_{{{tokd_idx}}}$"
            col = config["watermark.colors"][colors[tokd].item()]
            # Backslashes are escaped: "\s" and "\h" are invalid Python escape
            # sequences (SyntaxWarning on 3.12+); the output text is unchanged.
            latex += f"{{\\sethlcolor{{{col}}}\\hl{{{text}}}}}"
        latex += "\\ldots \n\\end{minipage} } \n\n"
    if title is not None:
        latex += f"\\caption{{{title}}} \\end{{figure*}}"
    buf = BytesIO()
    preview(latex, output='png', viewer='BytesIO', outputbuffer=buf,
            preamble=("\\documentclass[12pt]{article}"
                      "\\usepackage{amsmath}" "\\usepackage{soul}" "\\usepackage{xcolor}"
                      "\\begin{document}"))
    buf.seek(0)
    img = PILImage.open(buf)
    # Keep only the top `height` pixels per rendered run.
    img_top = img.crop((0, 0, img.size[0], height * len(datas)))
    display(img_top)
    return latex

#==== 04 ====#
def step_ppl_compute(data):
    """Perplexity of the generated suffixes, pooled over the whole batch.

    Runs the model on prefix + suffix and lets its built-in loss average the
    cross-entropy over suffix tokens only. Prefix positions and tokens equal
    to the pad id get label -100, which the loss ignores.

    Sets ``eval.loss`` (mean token loss) and ``eval.ppl`` (its exp).
    """
    prefixes = data['prefixes_tokd']
    suffixes = data['suffixes_tokd']
    with torch.no_grad():
        input_ids = torch.concat((prefixes['input_ids'], suffixes), dim=-1)
        # Prefix padding must not be attended to; generated tokens are real.
        attention_mask = torch.concat(
            (prefixes['attention_mask'], torch.ones_like(suffixes)), dim=-1)
        # -100 belongs only in the labels. Written into input_ids, as before,
        # it is an out-of-range embedding index as soon as padding exists.
        labels = input_ids.clone()
        labels[labels == data["tokenizer"].pad_token_id] = -100
        labels[:, :data['prefixes_tokd_len']] = -100
        outputs = data["model"](
            input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        data["eval.loss"] = loss.item()
        data["eval.ppl"] = torch.exp(loss).item()
    return data

def step_ppl_compute_line_slow(data):
    """Per-line perplexity using one forward pass per batch row.

    Reference counterpart of ``step_ppl_compute_line_fast``. For each row the
    model scores prefix + suffix, and its built-in loss averages over suffix
    tokens (prefix and pad-id labels are -100).

    Sets ``eval.loss_per_line`` and ``eval.ppl_per_line`` (one entry per
    row), and ``eval.loss`` / ``eval.ppl`` (their means).
    """
    model = data["model"]
    model.eval()
    prefixes = data['prefixes_tokd']
    suffixes = data['suffixes_tokd']
    encodings = torch.concat((prefixes['input_ids'], suffixes), dim=-1)
    attention = torch.concat(
        (prefixes['attention_mask'], torch.ones_like(suffixes)), dim=-1)
    encodings_masked = encodings.clone().detach()
    encodings_masked[encodings_masked == data["tokenizer"].pad_token_id] = -100
    encodings_masked[:, :data['prefixes_tokd_len']] = -100
    losses, ppls = [], []
    with torch.no_grad():
        for encoding, mask_row, labels in zip(encodings, attention, encodings_masked):
            # Iterating yields 1-D rows; the model expects (batch, seq).
            outputs = model(encoding.unsqueeze(0),
                            attention_mask=mask_row.unsqueeze(0),
                            labels=labels.unsqueeze(0))
            losses.append(outputs.loss)
            ppls.append(torch.exp(outputs.loss))
    data["eval.loss_per_line"] = torch.stack(losses)
    data["eval.ppl_per_line"] = torch.stack(ppls)
    data["eval.loss"] = data["eval.loss_per_line"].mean().item()
    data["eval.ppl"] = data["eval.ppl_per_line"].mean().item()
    return data

def step_ppl_compute_line_fast(data):
    """Per-line perplexity of the generated suffixes in one batched forward pass.

    Token cross-entropy is computed manually (logits at i-1 predict token i)
    and averaged per row over suffix tokens only. Prefix and pad-id positions
    are excluded through the -100 label convention.

    Sets ``eval.loss_per_line`` and ``eval.ppl_per_line`` (shape ``(rows,)``),
    and ``eval.loss`` / ``eval.ppl``, the means over rows. ``eval.ppl`` is the
    mean of per-line perplexities, not exp of the mean loss.
    """
    prefixes_tokd_len = data['prefixes_tokd_len']
    prefixes = data['prefixes_tokd']
    suffixes_tokd = data['suffixes_tokd']
    line_count = prefixes['input_ids'].shape[0]

    input_ids = torch.concat((prefixes['input_ids'], suffixes_tokd), dim=-1)
    attention_mask = torch.concat(
        (prefixes['attention_mask'], torch.ones_like(suffixes_tokd)), dim=-1)
    # -100 only in the labels; as an input id it is an invalid embedding index.
    labels = input_ids.clone()
    labels[labels == data["tokenizer"].pad_token_id] = -100
    if prefixes_tokd_len > 0: labels[:, :prefixes_tokd_len] = -100

    with torch.no_grad():
        outputs = data["model"](input_ids=input_ids, attention_mask=attention_mask)
        # Shift so that label[i] is predicted from logit[i-1].
        shifted_logits = outputs.logits[..., :-1, :].contiguous()
        shifted_labels = labels[..., 1:].contiguous()
        loss = F.cross_entropy(
            shifted_logits.view(-1, shifted_logits.size(-1)),
            shifted_labels.view(-1),
            reduction='none',
        ).view(line_count, -1)
        # Average only over positions with a real label.
        mask = shifted_labels != -100
        loss_per_line = (loss * mask).sum(dim=1) / mask.sum(dim=1)
        ppl_per_line = torch.exp(loss_per_line)
        data["eval.loss_per_line"] = loss_per_line
        data["eval.ppl_per_line"] = ppl_per_line
        data["eval.loss"] = loss_per_line.mean().item()
        data["eval.ppl"] = ppl_per_line.mean().item()
    return data

def plot_ppl_line_by_line(datas, labels, inner=None,
                          title=None, boxsize=2, ylim=(2, 10)):
    """Violin plot of per-line perplexities, one violin per run.

    Args:
        datas: runs carrying ``eval.ppl_per_line``.
        labels: x tick label per run.
        inner: seaborn ``inner`` style for the filled violins
            (box/quartile/point/stick/None).
        title: optional plot title.
        boxsize: line width of the black box plot.
        ylim: y-axis limits.
    """
    if title is not None:
        plt.title(title)
    per_run_ppls = [run["eval.ppl_per_line"].tolist() for run in datas]
    shared = dict(data=per_run_ppls, density_norm="area")  # or: count, width
    # Pass 1: filled violins with faint inner marks.
    sns.violinplot(**shared, inner=inner,
                   linewidth=0, linecolor="white",
                   inner_kws={'linewidth': 2, 'alpha': .05},
                   zorder=-1, alpha=.9)
    # Pass 2: transparent bodies (alpha=0); only the black inner box shows.
    sns.violinplot(**shared, inner="box", alpha=0,
                   linewidth=boxsize, linecolor="black",
                   zorder=-2)
    plt.ylim(*ylim)
    plt.ylabel('PPL')
    plt.grid(zorder=0, color='black', alpha=.2, linestyle="--", linewidth=1)
    plt.xticks(np.arange(len(labels)), labels=labels)

#==== 05 ====#
def compute_tokd_color(data):
    """Recover the color of every generated token from the color function.

    For each suffix position ``pos``, ``config["watermark.fn_colors"]`` is
    re-evaluated on all tokens before ``pos`` (prefix included, as in
    ``Perturber``), and the color of the token actually emitted at ``pos`` is
    read off.

    Returns a new dict ``{'suffixes_tokd_color': tensor (batch, n_suffix)}``.
    This function itself does not write to ``data``.
    """
    config = data['config']
    fn_colors = config["watermark.fn_colors"]
    full_tokd = torch.cat(
        (data['prefixes_tokd']['input_ids'], data['suffixes_tokd']), dim=1)
    rows = torch.arange(full_tokd.size(0))
    result = {}
    for pos in range(data['prefixes_tokd_len'], full_tokd.size(1)):
        step_colors = fn_colors(tokd_prev=full_tokd[:, :pos], data=data)
        tensor_cat(result, 'suffixes_tokd_color', step_colors[rows, full_tokd[:, pos]])
    if config["verbose"]:
        print(f"tokd_color:\n{result['suffixes_tokd_color']}")
    return result

def step_score_watermark(data):
    """Score how consistently the generated token colors follow the pattern.

    Each length-``w`` window (``w = config['detection.window']``) of a row's
    token colors scores 1 if it appears in the color pattern repeated ``w``
    times, and 0 otherwise. The repetition contains every phase of the
    pattern. ``wm_detection.scores`` is the running sum of window scores
    divided by the number of windows, so its last column is the fraction of
    matching windows.

    Also stores ``suffixes_tokd_color``. Colors are compared as decimal
    digits, hence the limit of 10 pattern colors.
    """
    data.update(compute_tokd_color(data))
    config = data['config']
    w = config['detection.window']
    wm_pattern = config['watermark.color_pattern']
    assert min(wm_pattern) >= 0 and max(wm_pattern) < 10, "We don't support more than 10 colors atm"
    haystack = ''.join(map(str, wm_pattern)) * w
    # Scoring runs on CPU.
    tokd_color = data['suffixes_tokd_color'].cpu()
    n_rows, n_tokens = tokd_color.size(0), tokd_color.size(1)
    n_windows = n_tokens - w + 1
    scores = torch.zeros((n_rows, n_windows,))
    for row in range(n_rows):
        row_colors = tokd_color[row].tolist()
        for start in range(n_windows):
            window = ''.join(map(str, row_colors[start:start + w]))
            scores[row, start] = 1 if window in haystack else 0
    if config["verbose"]: print(f"scores:\n{scores}")
    running = torch.cumsum(scores, dim=-1)
    data['wm_detection.scores'] = running / running.size(1)
    return data

def plt_score_detection(datas, labels, colors, alpha=None):
    """Plot the cumulative detection score of every row of every run.

    Args:
        datas: runs carrying ``wm_detection.scores`` (CPU tensor, rows x windows).
        labels: legend label per run.
        colors: line color per run.
        alpha: line alpha. Defaults to ``2 / rows`` of the first run, so dense
            batches stay readable.
    """
    if alpha is None:
        alpha = 2.0 / datas[0]["wm_detection.scores"].size(0)
    last_score = None
    for data, label, color in zip(datas, labels, colors):
        last_score = data["wm_detection.scores"]
        plt.plot(list(range(last_score.size(1))), last_score.numpy().T,
                 color=color, alpha=alpha, label=label,
                 linewidth=2, marker='.')
    if last_score is not None:
        # Axes and legend depend only on the last run plotted, so they are set
        # once after all curves. Distinct names avoid shadowing `labels`.
        n_windows = last_score.size(1)
        plt.xlim(0, n_windows - 1)
        plt.axhline(0, color='black', linewidth=.5)
        plt.xlabel("$t$")
        plt.ylim(0, 1)
        # Every row draws its own line; keep one legend entry per label.
        legend_handles, legend_labels = plt.gca().get_legend_handles_labels()
        unique = dict(zip(legend_labels, legend_handles))
        leg = plt.legend(unique.values(), unique.keys())
        for lh in leg.legend_handles:
            lh.set_alpha(.5)
        plt.xticks(np.arange(0, n_windows), [])
        plt.grid(True, alpha=alpha)
    plt.tight_layout()

#==== 06 ====#
def step_detect_watermark(data):
    """Flag rows whose final detection score exceeds ``detection.threshold``.

    Sets ``detection.is_wmkd`` to an int tensor with one entry per row of
    ``wm_detection.scores`` (1 = flagged as watermarked, 0 = not).
    """
    final_scores = data["wm_detection.scores"][:, -1]
    threshold = data["config"]["detection.threshold"]
    data["detection.is_wmkd"] = final_scores.gt(threshold).int()
    return data

def _eval_is_wmked(data, expected, detection_threshold):
    """Fraction of rows whose detection verdict differs from ``expected``.

    Writes ``detection_threshold`` into ``data["config"]`` (creating the dict
    if missing), runs ``step_detect_watermark`` on ``data`` in place and
    compares each row's verdict with ``expected`` (0 or 1).
    """
    data.setdefault('config', {})["detection.threshold"] = detection_threshold
    verdicts = step_detect_watermark(data)["detection.is_wmkd"]
    n_wrong = (verdicts != expected).sum().item()
    return 1.0 * n_wrong / verdicts.size(0)
def eval_is_wmked(data_unmod, data_wtmkd, detection_threshold):
    """Return ``(type I, type II)`` detection error rates at a threshold.

    Type I is the share of unwatermarked rows flagged as watermarked. Type II
    is the share of watermarked rows that are not flagged.
    """
    false_positive_rate = _eval_is_wmked(data_unmod, 0, detection_threshold)
    false_negative_rate = _eval_is_wmked(data_wtmkd, 1, detection_threshold)
    return false_positive_rate, false_negative_rate

def bsearch_wmk_detection_threshold(data_unmod, data_wtmkd, t1_err_max):
    """Lowest scanned threshold whose type I error is at most ``t1_err_max``.

    Despite the name this is a linear scan. Thresholds rise from ``-eps`` in
    steps of ``eps = 1 / suffix_length``, and the first one meeting the budget
    is returned as ``(threshold, t1, t2)``. If none does, the last scanned
    threshold (about ``1 + eps``) is returned with its errors. As a side
    effect, both inputs are left configured with the returned threshold.
    """
    eps = 1.0 / data_unmod['suffixes_tokd'].size(1)
    candidates = np.arange(-eps, 1 + eps * 2, eps)
    for threshold in candidates:
        errors = eval_is_wmked(data_unmod, data_wtmkd, threshold)
        if errors[0] <= t1_err_max:
            break
    return (threshold, *errors)

def sweep_for_watermark_detection(
        data, data_key, data_vals, batch_count=5, batch_size=64, return_edited_datas=False):
    """Sweep one config value and measure detection errors and perplexity.

    For each of ``batch_count`` consecutive batches of ``batch_size`` prompts,
    an unwatermarked baseline (``watermark.delta = 0``) is generated once.
    Then, for every ``(name, value)`` in ``data_vals``:

    * set ``config[data_key] = value``;
    * generate watermarked text and score both runs;
    * pick the lowest threshold meeting ``detection.target_t1`` on the baseline;
    * compute the watermarked perplexity.

    Returns:
        A DataFrame indexed by value name with per-batch means of the
        detection threshold, type I/II errors and base/watermarked PPL. When
        ``return_edited_datas`` is true, also the list of baseline and
        watermarked states.
    """
    def _copy(state):
        # Deep-copy the state but share the loaded model and tokenizer rather
        # than duplicating them for every swept value.
        memo = {id(state[k]): state[k] for k in ("model", "tokenizer") if k in state}
        return copy.deepcopy(state, memo)

    data = steps_reduce([step_model_load, step_random_reset], data)
    df, result_datas = [], []

    for batch_idx in range(batch_count):
        data["config"]["data.lines"] = (batch_size*batch_idx, batch_size*(batch_idx+1))
        data_unmod = steps_reduce([step_prefixes_load, step_tokenize], data)
        # delta = 0: same generation path, no watermark bias.
        data_unmod["config"]["watermark.delta"] = 0
        data_unmod = steps_reduce([step_generate_watermarked, step_ppl_compute], data_unmod)

        tqdm_indicator = tqdm(data_vals)
        for data_val_name, data_val in tqdm_indicator:
            tqdm_label = f"batch_idx={batch_idx}, {data_key}={data_val_name}"
            # Fresh copy, so states already appended to result_datas are not
            # mutated by this iteration.
            data_unmod = _copy(data_unmod)
            data_wmked = _copy(data_unmod)
            del data_wmked["suffixes_tokd"]

            tqdm_indicator.set_description(f"{tqdm_label} GENERATEING")
            # The copy inherited the baseline's delta = 0. Restore the caller's
            # delta so sweeping any other key still generates watermarked text.
            if "watermark.delta" in data["config"]:
                data_wmked["config"]["watermark.delta"] = data["config"]["watermark.delta"]
            data_wmked["config"][data_key] = data_val
            data_wmked = steps_reduce([step_generate_watermarked,
                                       step_score_watermark], data_wmked)

            tqdm_indicator.set_description(f"{tqdm_label} ANALYZING")
            data_unmod["config"][data_key] = data_val
            data_unmod = steps_reduce([step_score_watermark], data_unmod)

            tqdm_indicator.set_description(f"{tqdm_label} BSEARCHING")
            th, e1, e2 = bsearch_wmk_detection_threshold(
                data_unmod, data_wmked, data_wmked["config"]["detection.target_t1"])

            tqdm_indicator.set_description(f"{tqdm_label} PERPLEXING")
            data_wmked = steps_reduce([step_ppl_compute], data_wmked)
            df.append({
                "batch_idx": batch_idx,
                "val": data_val_name,
                "detection.threshold": th,
                "detection.t1": e1,
                "detection.t2": e2,
                "ppl.base": data_unmod["eval.ppl"],
                "ppl.wmked": data_wmked["eval.ppl"],
            })
            if return_edited_datas:
                result_datas.append(data_unmod)
                result_datas.append(data_wmked)
            cache_clear(data)
            tqdm_indicator.set_description(f"batch_idx={batch_idx} DONE")
    df = pd.DataFrame(df)
    df = df.sort_values(by=['val'])
    if data["config"].get("log", False): display(df)
    df = df.groupby(['val']).mean()
    df = df.drop(columns=['batch_idx'])
    if not return_edited_datas: return df
    return df, result_datas

def plt_legend_unique(loc=None, fontsize=None):
    """Draw a legend on the current axes with one entry per distinct label.

    When several artists share a label, the handle of the last one is used,
    placed at the position where the label first appeared.
    """
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = {}
    for handle, label in zip(handles, labels):
        by_label[label] = handle
    plt.legend(list(by_label.values()), list(by_label.keys()),
               loc=loc, fontsize=fontsize)
def plt_watermark_sweep_single(
        df, label, color, y_key, y_label, ppl_base, xlim, ylim,
        annotation_alpha=.1, annotation_offset=(0, 10)):
    """Plot one sweep curve: watermarked PPL (x) against ``df[y_key]`` (y).

    Each point is annotated with its index label, and a dashed red vertical
    line marks the base perplexity ``ppl_base``.
    """
    plt.axvline(ppl_base, linestyle='--',
                color='tab:red', linewidth=2,
                label="Base PPL")
    plt.plot(df["ppl.wmked"], df[y_key],
             linewidth=2, marker='*', linestyle=':',
             label=label, color=color)
    # Walk the rows positionally. The index holds the swept value names, so
    # `Series[i]` is a label lookup for integer names and pandas' deprecated
    # positional fallback otherwise.
    for name, x, y in zip(df.index, df["ppl.wmked"], df[y_key]):
        plt.annotate(
            name, (x, y),
            textcoords="offset points", xytext=annotation_offset, ha='center',
            color='black', alpha=annotation_alpha)
    plt.ylabel(y_label)
    plt.ylim(*ylim)
    plt.axhline(0, color='black', linewidth=1)
    plt.xlabel("PPL (better ←)")
    plt.xlim(*xlim)
    plt.grid(True)
    plt_legend_unique()
def plt_watermark_sweep(dfs, labels=None, colors=None, annotation_alpha=1):
    """Plot type II (subplot 1x3 #2) and type I (#3) error against watermarked PPL.

    Args:
        dfs: sweep result frames (``ppl.base``, ``ppl.wmked``,
            ``detection.t1``, ``detection.t2``), one per curve.
        labels: legend label per frame (default ``"???"``).
        colors: line color per frame (default: matplotlib's choice).
        annotation_alpha: alpha of the per-point value annotations.
    """
    labels = ["???"] * len(dfs) if labels is None else labels
    colors = [None] * len(dfs) if colors is None else colors
    ppl_base = torch.tensor([df["ppl.base"].mean() for df in dfs]).mean()
    xlim = ppl_base - .25, max(df["ppl.wmked"].max() for df in dfs) + .25
    offset = [5, 5]
    panels = ((3, "detection.t1", "Type I Error (better ←)"),
              (2, "detection.t2", "Type II Error (better ←)"))
    for df, label, color in zip(dfs, labels, colors):
        # Alternate the horizontal annotation side from curve to curve.
        offset[0] = -offset[0]
        for position, y_key, y_label in panels:
            plt.subplot(1, 3, position)
            plt_watermark_sweep_single(
                df, label, color, y_key, y_label,
                ppl_base, xlim, (0, .6),
                annotation_alpha=annotation_alpha, annotation_offset=offset)
    plt.tight_layout()

#==== 07 ====#
def step_score_edit(data):
    """Per-token edit score: share of the windows covering a token that match.

    For position ``t``, each of the ``w`` windows of length ``w`` (``w =
    config['edit_detection.window']``) that contain ``t`` is checked against
    the watermark pattern in any phase. Windows hanging over either end of
    the sequence count as matching. ``edit_detection.scores[y, t]`` is the
    matching share, so low values point to a nearby edit. Also stores
    ``suffixes_tokd_color``.
    """
    for k, v in compute_tokd_color(data).items(): data[k] = v
    config = data['config']
    w = config['edit_detection.window']
    wm_pattern = config['watermark.color_pattern']
    assert min(wm_pattern) >= 0 and max(wm_pattern) < 10, "We don't support more than 10 colrs atm"
    # Every phase of a period-p pattern appears as a length-w substring only
    # if the repeated pattern is at least w + p - 1 long. ceil(2w/p)
    # repetitions can fall short: for pattern 0123 with w=2 it gives "0123",
    # which misses the wrap-around window "30".
    p = len(wm_pattern)
    wm_pattern_with_rots = math.ceil((w + p - 1) / p) * wm_pattern
    wm_pattern_with_rots = ''.join([str(i) for i in wm_pattern_with_rots])
    # Everything below runs on CPU.
    tokd_color = data['suffixes_tokd_color'].cpu()
    edit_scores = torch.zeros((tokd_color.size(0), tokd_color.size(1),))
    for y in range(edit_scores.size(0)):
        for t in range(edit_scores.size(1)):
            ss = []
            # The w windows containing t start at t-w+1 .. t.
            for w_idx in range(-w+1, 1):
                x = t + w_idx
                if x < 0 or x+w > edit_scores.size(1):
                    ss.append(1)
                    continue
                window = tokd_color[y, x:x+w]
                window = ''.join([str(i) for i in window.tolist()])
                ss.append(1 if window in wm_pattern_with_rots else 0)
            edit_scores[y, t] = 1.0 * sum(ss) / len(ss)
    if config["verbose"]: print(f"scores:\n{edit_scores}")
    data['edit_detection.scores'] = edit_scores
    return data

def step_edit_tokd_randomize(data):
    """Overwrite ``edits.count`` random suffix tokens per row with random ids.

    Edit positions are distinct within a row and drawn from ``[w, L - w)``
    (``w = edit_detection.window``, ``L`` = suffix length), keeping edits away
    from both ends. Replacement ids are uniform in
    ``[0, tokenizer.vocab_size)``. ``data['suffixes_tokd']`` is modified in
    place.

    Sets ``edits.positions`` (rows x count) and ``edits.mask`` (int tensor
    shaped like the suffixes, with 1 at edited positions).
    """
    w = data['config']["edit_detection.window"]
    tokens = data['suffixes_tokd']
    n_rows, n_tokens = tokens.size(0), tokens.size(1)
    n_edits = data["config"]['edits.count']
    # Per row: a random permutation of the interior slots; keep the first n_edits.
    perms = [torch.randperm(n_tokens - 2 * w) for _ in range(n_rows)]
    positions = torch.stack(perms)[:, :n_edits] + w
    values = torch.randint(
        low=0, high=data["tokenizer"].vocab_size, size=(n_rows, n_edits)).to(tokens.device)
    rows = torch.arange(n_rows).unsqueeze(1)
    tokens[rows, positions] = values
    data['edits.positions'] = positions
    mask = torch.zeros(tokens.shape, dtype=torch.int)
    mask[rows, positions] = 1
    data['edits.mask'] = mask
    return data
def plt_scores_with_edits(datas, labels, colors, alpha=.25):
    """``plt_score_detection`` plus black crosses at the edited positions.

    For runs carrying ``edits.positions``, each row gets a cross on its score
    curve at every edit position.

    NOTE(review): edit positions are token indices (< L - edit window), while
    ``wm_detection.scores`` has L - detection window + 1 columns. Confirm the
    indexing stays in range for the chosen windows.
    """
    plt_score_detection(datas, labels, colors, alpha)
    # zip with labels/colors keeps the original "stop at the shortest" behavior.
    for data, _label, _color in zip(datas, labels, colors):
        if 'edits.positions' not in data:
            continue
        xs = data['edits.positions'].cpu()
        wm_score = data["wm_detection.scores"]
        row_idx = torch.arange(wm_score.size(0)).unsqueeze(-1)
        ys = wm_score[row_idx, xs].cpu()
        plt.scatter(xs, ys, marker='x', color='black',
                    alpha=alpha*2, linewidths=3, label="edits")
    max_x = max(d['wm_detection.scores'].size(1) for d in datas)
    plt.xlim(0, max_x - 1)
    plt.tight_layout()

def _eval_stats(tp, tn, fp, fn):
    assert (tp+tn+fp+fn).sum() == tp.size(0)*tp.size(1)
    # scalar values
    tp = tp.sum(dim=-1)
    tn = tn.sum(dim=-1)
    fp = fp.sum(dim=-1)
    fn = fn.sum(dim=-1)
    # calculate stats
    t1 = torch.nan_to_num(fp / (fp + tn), nan=0)
    t2 = torch.nan_to_num(fn / (fn + tp), nan=0)
    prec = torch.nan_to_num(tp / (tp + fp), nan=0)
    recall = torch.nan_to_num(tp / (tp + fn), nan=0)
    f1 = torch.nan_to_num(2 * (prec * recall) / (prec + recall), nan=0)
    return torch.stack((t1, t2, prec, recall, f1,)).T
def _eval_classification_stats(A, D):
    """Position-wise confusion masks of actual edits ``A`` vs detections ``D``.

    Nonzero entries count as True. Returns ``(tp, tn, fp, fn, stats)``: int
    masks shaped like the inputs, plus ``stats`` from ``_eval_stats``.
    """
    actual = A.bool()
    detected = D.bool()
    tp = (actual & detected).int()
    tn = (~actual & ~detected).int()
    fp = (~actual & detected).int()
    fn = (actual & ~detected).int()
    return tp, tn, fp, fn, _eval_stats(tp, tn, fp, fn)
def _eval_edit_detection_slow(A, D, t):
    """Confusion masks for edit detection with a positional tolerance.

    ``t = (lo, hi)`` defines, for position x, the window ``[x + lo, x + hi]``
    clipped to the sequence:

    * actual edit at x: TP if any detection falls in the window, else FN;
    * no edit and no detection at x: TN;
    * no edit but a detection at x: TN if an actual edit lies in the window
      (tolerated near-miss), else FP.

    The masks have the dtype of ``D`` (bool when called from
    ``step_eval_edit_detection_with_tolerance``). Returns
    ``(tp, tn, fp, fn, stats)`` with ``stats`` from ``_eval_stats``.
    """
    assert isinstance(t, tuple) and len(t) == 2, "tolerance is not a list"
    n_rows, n_cols = D.size(0), D.size(1)
    tp, tn, fp, fn = (torch.zeros_like(D, device=D.device) for _ in range(4))
    lo, hi = t
    for y in range(n_rows):
        for x in range(n_cols):
            window = range(max(0, x + lo), min(n_cols, x + hi + 1))
            if A[y, x]:
                target = tp if any(D[y, xx] for xx in window) else fn
            elif D[y, x] == 0:
                target = tn
            elif any(A[y, xx] for xx in window):
                target = tn
            else:
                target = fp
            target[y, x] = 1
    return tp, tn, fp, fn, _eval_stats(tp, tn, fp, fn)
def _eval_edit_detection_fast(A, D, t):
    """Vectorised edit-detection evaluation with tolerance ``t = (lo, hi)``.

    Accumulates ``D[x + n]`` for every shift ``n`` in ``[lo, hi]`` into each
    position x; shifts that fall outside the sequence contribute nothing.
    For a bool ``D`` this accumulation saturates like a logical OR. The
    result is then classified against ``A`` via
    ``_eval_classification_stats``.

    NOTE(review): this does not match ``_eval_edit_detection_slow``. Here every
    non-edit position with a detection within tolerance counts as FP. The slow
    version only counts positions that are themselves detections, and
    excuses those near an actual edit.
    ``step_eval_edit_detection_with_tolerance`` uses the slow version.
    """
    assert isinstance(t, tuple) and len(t) == 2, "tolerance is not a list"
    width = D.size(1)
    widened = torch.zeros_like(D, device=D.device)
    for shift in range(t[0], t[1] + 1):
        if shift < 0:
            widened[:, -shift:] += D[:, :width + shift]
        elif shift > 0:
            widened[:, :width - shift] += D[:, shift:]
        else:
            widened += D
    return _eval_classification_stats(A, widened)

def step_eval_edit_detection_with_tolerance(data):
    """Evaluate edit detection against the actual edits, with tolerance.

    Positions with ``edit_detection.scores <= edit_detection.threshold``
    count as detected edits. Actual edits come from ``edits.mask``; when that
    key is absent, there are no edits. ``config["edit_detection.tolerance"]``
    (a ``(lo, hi)`` tuple) is passed on, and ``(tp, tn, fp, fn, stats)`` is
    stored in ``edit_detection.eval``.
    """
    scores = data["edit_detection.scores"]
    if "edits.mask" in data:
        actual = data["edits.mask"].to(scores.device)
    else:
        actual = torch.zeros_like(scores).int().to(scores.device)
    config = data["config"]
    detected = scores <= config["edit_detection.threshold"]
    data["edit_detection.eval"] = _eval_edit_detection_slow(
        actual, detected, config["edit_detection.tolerance"])
    return data

def plot_edit_score(data, row):
    """Bar-plot the per-token edit scores of one batch row relative to the threshold.

    Bars show ``edit_detection.scores[row] - threshold`` (threshold 0.5 when
    ``edit_detection.threshold`` is unset). Bars at or below zero are drawn green;
    these are the tokens ``step_eval_edit_detection_with_tolerance`` counts as
    detected (``score <= threshold``). TP/TN/FP/FN markers come from
    ``data["edit_detection.eval"]`` (required). Actual edit positions come from
    ``data["edits.positions"]`` when present. The title reports tolerance,
    threshold and the row's type-I/II error rates.

    Args:
        data: pipeline dict after edit scoring and evaluation.
        row: index of the batch row to plot.
    """
    plt.axhline(0, color='black', linewidth=1, zorder=-5)
    th = data["config"].get("edit_detection.threshold", 0.5)
    # .cpu() so matplotlib can consume the values even if scoring ran on a GPU.
    ys = (data["edit_detection.scores"][row] - th).cpu()
    xs = torch.arange(ys.size(0))
    plt.ylabel("$E(i)$")
    plt.yticks([])
    plt.xticks(np.arange(0, ys.size(0), 1), labels=[])
    plt.xlim(-1, ys.size(0))
    # bars: all scores, then the detected ones (score <= threshold) on top.
    # `<=` (not `<`) matches the evaluation rule and the "\leq" legend label.
    plt.bar(xs, ys, label="$>$ Threshold",
            color='silver', zorder=-4)
    ys_detected = ys <= 0
    plt.bar(xs[ys_detected], ys[ys_detected], label="$\\leq$ Threshold",
            color='tab:green', zorder=-3)
    # edit detection eval: one marker series per confusion-matrix cell
    tp, tn, fp, fn, st = data["edit_detection.eval"]
    marker_styles = (("TP", 'D', 'white'), ("TN", 'o', 'black'),
                     ("FP", 'D', 'red'), ("FN", 'o', 'gold'))
    for (name, marker, color), mask in zip(marker_styles, (tp, tn, fp, fn)):
        idx = torch.where(mask[row])[0].cpu()
        plt.scatter(idx, torch.zeros_like(idx), label=f"{name}={idx.size(0)}", marker=marker,
                    color=color, zorder=-2, linewidths=1, edgecolor='black')
    # edit positions
    if 'edits.positions' in data:
        x = data['edits.positions'].cpu()[row]
        plt.scatter(x, [ys.max().item()/3]*len(x), zorder=-1, linewidths=3,
                    label="Actual edits", color='tab:green', marker='x')
    # title: edits.count may be a scalar or a per-row sequence
    edit_count = 0
    if 'edits.count' in data['config']:
        edit_count = data['config']['edits.count']
        if hasattr(edit_count, '__getitem__'):
            edit_count = edit_count[row]
    st = st[row]
    t = data["config"]["edit_detection.tolerance"]
    stats_text = f", tolerance={t} threshold={th:.2f} $\\rightarrow$ $type1={st[0]*100:.1f}\\%, type2={st[1]*100:.1f}\\%$"
    plt.title(f"{edit_count}/{len(xs)} edits{stats_text}")
    # legend
    plt.legend(ncol=4, fontsize='smaller', loc='upper left')
    plt.tight_layout()

def step_bsearch_edit_detection_threshold_given_t1(data):
    """Binary-search ``edit_detection.threshold`` so the mean type-I error meets a target.

    The search interval is ``[-EPS, 1]``. Each probe writes the midpoint to
    ``data["config"]["edit_detection.threshold"]`` and re-runs
    ``step_eval_edit_detection_with_tolerance``. It then takes the first of the
    five batch-averaged statistics in ``data["edit_detection.eval"][-1]`` as t1
    (the sweeps in this file unpack them as ``t1, t2, prec, recall, f1``).
    If ``t1 <= config["edit_detection.target_t1"]`` the lower bound moves up;
    otherwise the upper bound moves down. This presumes t1 does not decrease as
    the threshold grows (TODO confirm).

    On exit the threshold is set to the midpoint of the final ``low`` and the
    ``low`` recorded the last time t1 changed value. ``data["edit_detection.eval"]``
    still reflects the last probed midpoint, not this final threshold. The callers
    in this file run ``step_eval_edit_detection_with_tolerance`` again afterwards.

    Prints ``mid, t1, t2`` for every probe when ``config["verbose"]`` is set.
    """
    t1_target = data["config"]["edit_detection.target_t1"]
    EPS = .0001
    # Lower bound just below 0; presumably edit scores are >= 0, so this lets the
    # search reach a threshold that flags no token — confirm.
    low = 0 - EPS
    high = 1
    prev_low, prev_t1 = None, None
    # The membership clause forces at least one evaluation; with these initial
    # bounds (gap 1 + EPS) the gap clause already guarantees that.
    while abs(low - high) > EPS or not "edit_detection.eval" in data:
        mid = (low + high) / 2.0
        data["config"]["edit_detection.threshold"] = mid
        data = step_eval_edit_detection_with_tolerance(data)
        t1, t2, _, _, _ = data["edit_detection.eval"][-1].mean(dim=0)
        if data["config"]["verbose"]: print(mid, t1, t2)
        if t1 <= t1_target: low = mid
        else: high = mid
        # Remember the `low` at which t1 last changed value; it anchors the final threshold.
        if prev_t1 is None or prev_t1 != t1: prev_low, prev_t1 = low, t1
    data["config"]["edit_detection.threshold"] = (prev_low + low) / 2
    return data

#==== 08 ====#
def _data_insert(data, key, P, V):
    """Insert ``V`` into the 2-D tensor ``data[key]`` so the values land at columns ``P``.

    ``P`` and ``V`` have one row per row of ``data[key]``. ``P`` indexes the
    *output*, which is ``P.size(1)`` columns wider than the input. The original
    values fill the remaining columns in order. The new tensor replaces
    ``data[key]`` and is returned.
    """
    src = data[key]
    n_rows, n_cols = src.size(0), src.size(1) + P.size(1)
    # True where an original value goes, False at the insertion positions.
    keep = torch.ones((n_rows, n_cols), dtype=torch.bool, device=src.device)
    keep.scatter_(dim=1, index=P, value=False)
    out = torch.zeros((n_rows, n_cols), dtype=src.dtype, device=src.device)
    out[keep] = src.flatten()
    out.scatter_(dim=1, index=P, src=V)
    data[key] = out
    return out
def _data_replace(data, key, P, V):
    """Overwrite ``data[key]`` in place with ``V`` at column positions ``P`` (per row).

    Returns the same (mutated) tensor object stored under ``data[key]``.
    """
    data[key].scatter_(dim=1, index=P, src=V)
    return data[key]
def _data_remove(data, key, P):
    """Delete columns ``P`` (per row) from the 2-D tensor ``data[key]``.

    Every row must lose the same number of columns (``P.size(1)``). The shorter
    tensor replaces ``data[key]``.

    Returns:
        ``(shortened_tensor, first_pos)``, where ``first_pos`` is ``P[:, :1]``
        clamped to the last valid column of the shortened tensor.
    """
    src = data[key]
    n_rows, n_cols = src.size(0), src.size(1)
    keep = torch.ones_like(src, dtype=torch.bool)
    keep.scatter_(dim=1, index=P, value=False)
    kept = src[keep].reshape(n_rows, n_cols - P.size(1))
    first_pos = P[:, :1].clamp(max=kept.size(1) - 1)
    data[key] = kept
    return kept, first_pos
def step_edit_tokd_randomize_contiguous(data):
    """Apply one contiguous run of random-token edits to every row of ``suffixes_tokd``.

    Config keys read: ``edits.type`` ("INSERT", "REPLACE" or "REMOVE"),
    ``edits.count`` (run length) and ``edits.buffer`` (number of tokens left
    untouched at each end of a row). Each row gets its own random start position.
    Inserted/replacement tokens are drawn uniformly from
    ``[0, tokenizer.vocab_size)``.

    Writes ``suffixes_tokd``, ``edits.positions``, ``edits.mask`` (1 at edit
    positions) and ``config["model.token_count"]``. Deletes ``suffixes_colors``,
    ``suffixes_logits`` and ``suffixes_logits_unmodified`` if present. For REMOVE,
    ``edits.positions`` has a single column: where the removed run started,
    clamped into the shortened row.

    Raises:
        ValueError: if ``edits.type`` is not one of the three supported kinds
            (previously such a type silently left the tokens unchanged while still
            writing an edit mask).
    """
    edits_type = data["config"]['edits.type']
    if edits_type not in ("INSERT", "REPLACE", "REMOVE"):
        raise ValueError(
            f"unsupported edits.type {edits_type!r}; expected INSERT, REPLACE or REMOVE")
    edits_count = data["config"]['edits.count']
    edits_buffer = data["config"]['edits.buffer']
    vocab_range = (0, data["tokenizer"].vocab_size)
    T = data['suffixes_tokd']
    V = torch.randint(
        low=vocab_range[0],
        high=vocab_range[1],
        size=(T.size(0), edits_count,),
        dtype=T.dtype, device=T.device)
    # Consecutive positions start..start+count-1 per row. For INSERT they index the
    # longer output row, so the run may start up to `edits_count` columns later.
    P = torch.randint(
        low=+edits_buffer,
        high=-edits_buffer+T.size(1)+1 - (edits_count if edits_type != "INSERT" else 0),
        size=(T.size(0), 1,),
        dtype=torch.long, device=T.device) + torch.arange(V.size(1), device=T.device)
    if edits_type == "INSERT":
        T = _data_insert(data, 'suffixes_tokd', P, V)
    elif edits_type == "REPLACE":
        T = _data_replace(data, 'suffixes_tokd', P, V)
    else:  # "REMOVE"
        T, P = _data_remove(data, 'suffixes_tokd', P)
    data['suffixes_tokd'] = T
    data['edits.positions'] = P
    data["config"]["model.token_count"] = T.size(1)
    mask = torch.zeros(T.shape, device=T.device).long()
    mask.scatter_(dim=1, index=P, value=True)
    data['edits.mask'] = mask
    for k in ["suffixes_colors", "suffixes_logits",
              "suffixes_logits_unmodified"]:
        if k in data: del data[k]
    return data

def plot_edit_score(datas, row, wm_score_alpha=.75):
    """Stack one subplot per entry of ``datas`` showing the edit scores of ``row``.

    Each subplot is drawn relative to that entry's ``edit_detection.threshold``
    (0.5 if unset) and shows:
      * background bars coloured by each token's watermark colour;
      * per-token edit scores, highlighting those at or below the threshold
        (the ``score <= threshold`` rule of the evaluation step);
      * TP/TN/FP/FN markers when ``edit_detection.eval`` is present;
      * actual edit positions when ``edits.positions`` is present;
      * the first difference of ``wm_detection.scores`` as a dashed white line.
    All subplots share the same y-limits.

    Args:
        datas: list of pipeline dicts to compare (e.g. original and edited).
        row: index of the batch row to plot.
        wm_score_alpha: opacity of the watermark-score difference line.
    """
    def _threshold(d):
        return d["config"].get("edit_detection.threshold", 0.5)

    max_token_idx = max(d["edit_detection.scores"].size(1) for d in datas)
    # Shared y-range so all subplots are directly comparable.
    y_min, y_max = -.75, .5
    for data in datas:
        ys = data["edit_detection.scores"][row] - _threshold(data)
        y_min = min(y_min, ys.min().item()-.75)
        y_max = max(y_max, ys.max().item()+.5)
    for data_idx, data in enumerate(datas):
        plt.subplot(len(datas), 1, data_idx+1)
        th = _threshold(data)
        # labels and limits
        plt.ylim(y_min, y_max)
        plt.axhline(0, color='black', linewidth=1, zorder=-5)
        # .cpu() so matplotlib can consume the values even if scoring ran on a GPU.
        ys = (data["edit_detection.scores"][row] - th).cpu()
        xs = torch.arange(ys.size(0))
        plt.yticks(torch.zeros((1,)), [f"{th:.2f}"])
        plt.xticks(np.arange(0, ys.size(0), 5))
        plt.xlim(-.5, max_token_idx-.5)
        # background: watermark colour of each token
        tokd_is_green = data['suffixes_tokd_color'][row]
        colors = [data["config"]["watermark.colors"][c] for c in tokd_is_green]
        for y in [min(-1, y_min), max(+1, y_max)]:
            plt.bar(xs, [y*10]*xs.size(0), zorder=-5, width=1, alpha=.1,
                    color=colors,)
        # edit scores; `<=` matches the evaluation rule and the "\leq" label
        plt.bar(xs, ys, label="$> \\tau_e$",
                color='darkgrey', zorder=-4, width=.7)
        ys_detected = ys <= 0
        plt.bar(xs[ys_detected], ys[ys_detected], label="$\\leq \\tau_e$",
                color='steelblue', zorder=-3, width=.7)
        # evaluation
        if "edit_detection.eval" in data:
            tp, tn, fp, fn, st = data["edit_detection.eval"]
            marker_styles = (("TP", tp, 'D', 'white'), ("TN", tn, 'o', 'black'),
                             ("FP", fp, 'D', 'magenta'), ("FN", fn, 'o', 'cyan'))
            for name, mask, marker, color in marker_styles:
                idx = torch.where(mask[row])[0].cpu()
                plt.scatter(idx, torch.zeros_like(idx), label=f"{name}={idx.size(0)}",
                            marker=marker, color=color, zorder=-2, linewidths=1,
                            edgecolor='black', alpha=.5)
        # edit positions
        if 'edits.positions' in data:
            x = data['edits.positions'].cpu()[row]
            plt.scatter(x, [y_max/2]*len(x), zorder=-2, linewidths=6,
                        color='black', marker='o', alpha=1)
            marker = {
                "INSERT": '^', "REPLACE": 'X', "REMOVE": 'v',
            }[data["config"]["edits.type"]]
            plt.scatter(x, [y_max/2]*len(x), zorder=-1,
                        label=data["config"]["edits.type"].lower(),
                        color='orange', marker=marker)
        # title and threshold; edits.count may be a scalar or a per-row sequence
        edit_count = 0
        if 'edits.count' in data['config']:
            edit_count = data['config']['edits.count']
            if hasattr(edit_count, '__getitem__'):
                edit_count = edit_count[row]
        plt.axhline(0, zorder=-4, linewidth=2,
                    label="$|s|_E(t)=\\tau_e$", color='steelblue')
        if "edit_detection.eval" in data:
            st = st[row]
            t = data["config"]["edit_detection.tolerance"]
            t = min(t), max(t)
            plt.title(f"{edit_count} {data['config']['edits.type']}s: tolerance={t}, "+
                      f"$w$={data['config']['edit_detection.window']} "+
                      f"$\\rightarrow$ t1=${st[0]*100:.1f}\\%$, t2=${st[1]*100:.1f}\\%$")
            ncol = 4
        else:
            plt.title(f"Original: {len(xs)} tokens, $\\delta$={data['config']['watermark.delta']}")
            ncol = 3
        # s-score: step-to-step change of the watermark detection score
        ys = data["wm_detection.scores"][row].cpu()
        ys = ys[1:] - ys[:-1]
        xs = torch.arange(ys.size(0)) + data["config"]["detection.window"] // 2
        plt.plot(xs, ys, zorder=-3, alpha=wm_score_alpha, color='white', linewidth=1,
                 linestyle='--', marker='o', markersize=2)

        plt.legend(ncol=ncol, loc='lower right', fontsize='smaller')
        plt.tight_layout()

#==== 10 ====#
def plot_edit_score(data, row, tok_count=14, wm_score_alpha=.75, legend=True):
    """Publication-style plot of the first ``tok_count`` edit scores of ``row``.

    Shows watermark-coloured background bars and the edit score of each token
    (red when ``score <= edit_detection.threshold``, the rule used by
    ``step_eval_edit_detection_with_tolerance``; threshold 0.5 if unset).
    It also shows true/false/missed detection markers when
    ``edit_detection.eval`` is present, actual edits when ``edits.positions`` is
    present, and the threshold line.

    Args:
        data: pipeline dict after edit scoring.
        row: index of the batch row to plot.
        tok_count: number of leading tokens to show.
        wm_score_alpha: unused here; kept so the signature matches the other
            ``plot_edit_score`` variant.
        legend: whether to draw the legend to the right of the axes.
    """
    y_min, y_max = -.5, 1.5
    row_len = tok_count
    th = data["config"].get("edit_detection.threshold", 0.5)
    # labels and limits
    plt.ylim(y_min, y_max)
    scores = data["edit_detection.scores"][row][:row_len].cpu()
    # Bars are lifted by .05 so a zero score still shows a visible bar. Detection
    # is decided on the raw scores, not on this display offset.
    ys = scores + .05
    xs = torch.arange(ys.size(0))
    plt.ylabel("Detection\nStatistics")
    yticks = [th, 1]
    plt.yticks(yticks, labels=[f"{t:.1f}" for t in yticks])
    plt.xlabel("Token position")
    plt.xticks(np.arange(row_len))
    plt.xlim(-.5, row_len-.5)
    # background: watermark colour of each plotted token
    tokd_is_green = data['suffixes_tokd_color'][row][:row_len]
    colors = [data["config"]["watermark.colors"][c] for c in tokd_is_green]
    for y in [min(-1, y_min), max(+1, y_max)]:
        plt.bar(xs, [y*10]*xs.size(0), zorder=-5, width=1, alpha=.1,
                color=colors,)
    # edit scores
    plt.bar(xs, ys, label="No edit detected",
            color='black', zorder=-4, width=.7, alpha=.25)
    ys_detected = scores <= th
    plt.bar(xs[ys_detected], ys[ys_detected], label="Edit detected",
            color='tab:red', zorder=-3, width=.7)
    # evaluation (true negatives are not drawn)
    if "edit_detection.eval" in data:
        tp, _tn, fp, fn, _st = data["edit_detection.eval"]
        marker_styles = (("True Detection", tp, 'D', 'white'),
                         ("False Alarm", fp, 'D', 'magenta'),
                         ("Missed Detection", fn, 'o', 'cyan'))
        for label, mask, marker, color in marker_styles:
            idx = torch.where(mask[row][:row_len])[0].cpu()
            plt.scatter(idx, torch.full(idx.shape, 1.25), label=label, marker=marker,
                        color=color, zorder=-2, linewidths=1, edgecolor='black')
    # edit positions
    if 'edits.positions' in data:
        x = data['edits.positions'][row]
        marker = 'X'
        plt.scatter(x.cpu(), [-.25]*len(x), zorder=-1, label="Actual edit",
                    color='tab:red', marker=marker)
    # title and threshold
    plt.axhline(th, linewidth=1, zorder=-5,
                label="Detection threshold", color='tab:red')
    plt.axhline(0, linewidth=1, zorder=-5, color='black')
    plt.tight_layout()
    for side in ('top', 'right', 'left', 'bottom'):
        plt.gca().spines[side].set_visible(False)
    if legend:
        plt.legend(ncol=2, loc='center left', fontsize='smaller',
                   bbox_to_anchor=(1, .5))


def inline_diff_color(a: str, b: str) -> tuple[str, str]:
    """Word-level diff of two strings, highlighted with ANSI terminal colours.

    Words are obtained with ``str.split()`` and aligned with
    ``difflib.SequenceMatcher``. Words present only in ``a`` (deleted or replaced)
    are wrapped in red on the first line. Words present only in ``b`` (inserted
    or replacing) are wrapped in green on the second line. Matching words appear
    uncoloured on both lines.

    Returns:
        ``(a_line, b_line)``, each a space-joined string with ANSI codes.
    """
    RED, GREEN, RESET = "\x1b[31m", "\x1b[32m", "\x1b[0m"
    a_words, b_words = a.split(), b.split()
    matcher = difflib.SequenceMatcher(None, a_words, b_words, autojunk=False)
    a_line_parts, b_line_parts = [], []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        old = ' '.join(a_words[i1:i2])
        new = ' '.join(b_words[j1:j2])
        if tag == 'equal':
            a_line_parts.append(old)
            b_line_parts.append(old)
            continue
        # 'replace' contributes to both lines, 'delete' / 'insert' to one each.
        if tag in ('delete', 'replace'):
            a_line_parts.append(f"{RED}{old}{RESET}")
        if tag in ('insert', 'replace'):
            b_line_parts.append(f"{GREEN}{new}{RESET}")
    return ' '.join(a_line_parts), ' '.join(b_line_parts)

def apply_meaningful_edit(text_orig, gemini_api_key, model="gemini-2.5-flash-preview-04-17"):
    """Ask a Gemini model to make a small, localised edit to ``text_orig``.

    The system instruction asks for a change of only 1-3 words near the centre
    of the text, without formatting changes. The response is streamed with
    thinking disabled (``thinking_budget=0``) and returned as one string.

    Args:
        text_orig: text to edit.
        gemini_api_key: API key passed to ``genai.Client``.
        model: Gemini model name (defaults to the previously hard-coded model).

    Returns:
        The concatenated streamed response text.
    """
    client = genai.Client(
        api_key=gemini_api_key,
    )
    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_text(text=text_orig),
            ],
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        response_mime_type="text/plain",
        thinking_config=types.ThinkingConfig(
            thinking_budget=0,
        ),
        system_instruction=[
            types.Part.from_text(
                text="You are an expert on grammar, journalism, and sentiment. "
                "I will give you a text, and you only job is to modify one subset of the text. "
                "Do not touch the beginning or end of text. Keep the changes towards the center. "
                "Keep the changes localized to only 1-3 words of the text. "
                "Do not change formatting, do not add bold/italic annotation, do not add prefixes or suffixes. "),
        ],
    )
    stream = client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    )
    # A streamed chunk's `.text` can be None (e.g. a chunk carrying no text
    # parts); joining None would raise TypeError, so treat it as empty.
    return ''.join(chunk.text or '' for chunk in stream)

#==== 11 ====#
def sweep_contiguous_edits_count(data, edit_types=("INSERT", "REPLACE", "REMOVE",),
                                 edit_counts=(1,2,3,4,5), batch_count=16, batch_size=64,
                                 return_edited_datas=False):
    """Measure edit-detection quality for one contiguous random edit of varying length.

    The procedure runs on each of ``batch_count`` batches of ``batch_size`` prefix
    lines. It loads, tokenizes, generates and scores watermarked text. Then, for
    every (edit count, edit type) pair, it applies
    ``step_edit_tokd_randomize_contiguous`` and re-scores. It binary-searches the
    edit threshold against ``config["edit_detection.target_t1"]``, evaluates, and
    records the batch-mean metrics.

    Returns:
        DataFrame of metrics averaged over batches, indexed by
        ``(edits.type, edits.count)``. With ``return_edited_datas=True`` it also
        returns the list of pipeline dicts (each un-edited batch followed by its
        edited variants). The full per-batch table is displayed when
        ``config["log"]`` is truthy.
    """
    df, result_datas = [], []
    data = steps_reduce([step_model_load, step_random_reset], data)
    for batch_idx in range(batch_count):
        data["config"]["data.lines"] = (batch_size*batch_idx, batch_size*(batch_idx+1))
        data_batch = steps_reduce([
            step_prefixes_load, step_tokenize,
            step_generate_watermarked, step_score_watermark, step_score_edit], data)
        # Keep batch dicts only when they will be returned; otherwise every batch
        # would stay alive in memory until the whole sweep finishes.
        if return_edited_datas: result_datas.append(data_batch)

        tqdm_indicator = tqdm(sorted(itertools.product(edit_counts, edit_types)))
        for edit_count, edit_type in tqdm_indicator:
            tqdm_indicator.set_description(f"batch_idx={batch_idx}, {edit_type} {edit_count}")
            data_edited = copy.deepcopy(data_batch)
            data_edited["suffixes_tokd"] = data_edited["suffixes_tokd"].clone()
            data_edited["config"]["edits.type"] = edit_type
            data_edited["config"]["edits.count"] = edit_count
            data_edited = steps_reduce([
                step_edit_tokd_randomize_contiguous, step_score_watermark, step_score_edit,
                step_bsearch_edit_detection_threshold_given_t1,
                step_eval_edit_detection_with_tolerance],
                data_edited)
            t1, t2, prec, recall, f1 = data_edited["edit_detection.eval"][-1].mean(dim=0).numpy()
            df.append({
                "batch_idx": batch_idx,
                "edits.type": edit_type,
                "edits.count": edit_count,
                "edits.frac": edit_count / data_edited["suffixes_tokd"].size(1),
                "wm_detection.scores": data_edited["wm_detection.scores"][:,-1].mean().item(),
                "edit_detection.scores": data_edited["edit_detection.scores"][:,-1].mean().item(),
                "edit_detection.threshold": data_edited["config"]["edit_detection.threshold"],
                "edit_detection.t1": t1.item(),
                "edit_detection.t2": t2.item(),
                "edit_detection.prec": prec.item(),
                "edit_detection.recall": recall.item(),
                "edit_detection.f1": f1.item(),
            })
            if return_edited_datas: result_datas.append(data_edited)
            cache_clear(data)
            tqdm_indicator.set_description(f"batch_idx={batch_idx} DONE")
    df = pd.DataFrame(df)
    df = df.sort_values(by=['edits.type', 'edits.count'])
    if data["config"].get("log", False): display(df)
    df = df.groupby(['edits.type', 'edits.count']).mean()
    df = df.drop(columns=['batch_idx'])
    if not return_edited_datas: return df
    return df, result_datas

#==== 12 ====#
def step_edit_tokd_randomize_contiguous_multiple(data):
    """Apply ``edits.count`` separated runs of ``edits.width`` random-token edits per row.

    Config keys read:
      * ``edits.type``: "INSERT", "REPLACE" or "REMOVE";
      * ``edits.count``, ``edits.width``, ``edits.buffer``;
      * ``edit_detection.window`` (optional): if present, ``edits.buffer`` must be
        at least ``edits.width + edit_detection.window``;
      * ``edits.max_retries`` (optional, default 1_000_000): cap on the rejection
        sampling of run start positions.

    Run starts are drawn per row with ``random.randint`` and redrawn until
    consecutive starts are at least a type-dependent distance apart.
    Replacement/inserted tokens are drawn uniformly from
    ``[0, tokenizer.vocab_size)``. The function writes ``suffixes_tokd``,
    ``edits.positions`` (every edited position for INSERT/REPLACE; the shifted
    start of each removed run for REMOVE), ``edits.mask`` and
    ``config["model.token_count"]``. It deletes ``suffixes_colors``,
    ``suffixes_logits`` and ``suffixes_logits_unmodified`` if present.

    Raises:
        ValueError: unsupported ``edits.type``.
        RuntimeError: no valid placement found within ``edits.max_retries`` draws
            (the previous loop could spin forever on infeasible settings).
    """
    edits_type = data["config"]['edits.type']
    edits_count = data["config"]['edits.count']
    edits_width = data["config"]['edits.width']
    edits_buffer = data["config"]['edits.buffer']
    if "edit_detection.window" in data["config"]:
        assert edits_buffer >= edits_width + data["config"]["edit_detection.window"]
    max_retries = data["config"].get("edits.max_retries", 1_000_000)
    T = data['suffixes_tokd']
    vocab_range = (0, data["tokenizer"].vocab_size)
    n_tok = T.size(1)
    # Per type: minimum gap between consecutive run starts, last allowed start,
    # and the shortest row that can host all runs.
    if edits_type == "INSERT":
        min_dist = edits_buffer+1
        edits_end = n_tok-edits_buffer+1
        min_len = edits_buffer*(edits_count+1)
    elif edits_type == "REPLACE":
        min_dist = edits_buffer+edits_width
        edits_end = n_tok - edits_buffer - edits_width
        min_len = edits_buffer + edits_count*min_dist
    elif edits_type == "REMOVE":
        min_dist = edits_buffer+edits_width+1
        edits_end = n_tok - edits_buffer - edits_width - 1
        min_len = edits_buffer + edits_count*min_dist
    else:
        raise ValueError(
            f"unsupported edits.type {edits_type!r}; expected INSERT, REPLACE or REMOVE")

    def random_edit_positions():
        # Rejection sampling: draw sorted starts, accept when all gaps >= min_dist.
        assert min_len <= n_tok, f"model.token_count should be >= {min_len}"
        for _ in range(max_retries):
            rs = sorted([
                random.randint(edits_buffer, edits_end)
                for _ in range(edits_count)])
            if all(b - a >= min_dist for a, b in zip(rs, rs[1:])):
                return rs
        raise RuntimeError("Too many retries")

    W = torch.arange(0, edits_width, device=T.device)
    P_narrow = torch.tensor([random_edit_positions() for _ in range(T.size(0))],
                            device=T.device, dtype=torch.int64)
    # Expand each run start into edits_width consecutive positions.
    P = (P_narrow.unsqueeze(-1) + W).reshape(T.size(0), -1)
    V = torch.randint(
        low=vocab_range[0], high=vocab_range[1],
        size=P.shape, dtype=T.dtype, device=T.device)
    if edits_type == "INSERT":
        T = _data_insert(data, 'suffixes_tokd', P, V)
    elif edits_type == "REPLACE":
        T = _data_replace(data, 'suffixes_tokd', P, V)
    else:  # "REMOVE"
        T, _ = _data_remove(data, 'suffixes_tokd', P)
        # After deleting runs 0..k-1, run k starts k*edits_width columns earlier.
        P = P_narrow - torch.arange(P_narrow.size(1), device=T.device) * edits_width
    data['suffixes_tokd'] = T
    data['edits.positions'] = P
    data["config"]["model.token_count"] = T.size(1)
    mask = torch.zeros(T.shape, device=T.device).long()
    mask.scatter_(dim=1, index=P, value=True)
    data['edits.mask'] = mask
    for k in ["suffixes_colors", "suffixes_logits",
              "suffixes_logits_unmodified"]:
        if k in data: del data[k]
    return data

#==== 13 ====#
def step_detect_edit(data):
    """Flag each row as edited when its lowest edit score is strictly below the threshold.

    Stores an int tensor of shape ``(rows, 1)`` in
    ``data["edit_detection.is_edited"]`` and returns ``data``. Note the strict
    ``<``: the per-token evaluation in ``step_eval_edit_detection_with_tolerance``
    uses ``<=`` instead.
    """
    lowest = data["edit_detection.scores"].min(dim=-1).values
    threshold = data["config"]["edit_detection.threshold"]
    flagged = (lowest < threshold).int()
    data["edit_detection.is_edited"] = flagged.unsqueeze(1)
    return data

def _eval_is_edited(data, expected, detection_threshold):
    """Fraction of rows whose whole-row edit flag differs from ``expected``.

    Stores ``detection_threshold`` in ``data["config"]`` (creating the config
    dict if needed) and runs ``step_detect_edit``, so ``data`` is mutated.
    """
    data.setdefault('config', {})["edit_detection.threshold"] = detection_threshold
    data = step_detect_edit(data)
    flags = data["edit_detection.is_edited"]
    wrong = torch.sum(flags != expected).item()
    return 1.0 * wrong / flags.size(0)
def eval_is_edited(data_unmod, data_wtmkd, detection_threshold):
    """Whole-row edit-detection errors at one threshold.

    Returns:
        ``(t1, t2)``: the fraction of rows of ``data_unmod`` flagged as edited
        (expected 0), and the fraction of rows of ``data_wtmkd`` not flagged
        (expected 1).
    """
    return (_eval_is_edited(data_unmod, 0, detection_threshold),
            _eval_is_edited(data_wtmkd, 1, detection_threshold))

def _find_sublist(sub, main):
    """Locate the longest suffix of ``sub`` inside ``main`` and return the index just past it.

    ``sub`` (a tensor, scalar or 1-D) is turned into a string by concatenating
    ``str`` of its elements. Suffixes are tried from longest to shortest. For the
    longest one found in ``main``, the result is its first index in ``main`` plus
    its length. Returns 0 when no suffix occurs.
    """
    if len(sub.shape) == 0:
        needle = str(sub.item())
    else:
        needle = ''.join(str(v) for v in sub.tolist())
    for size in range(len(needle), 0, -1):
        suffix = needle[-size:]
        if suffix in main:
            return main.index(suffix) + size
    return 0
def fn_pattern_using_prev(tokd_prev, data):
    """Return the watermark colour index to use for the next token of every row.

    ``config["watermark.color_pattern"]`` is a sequence of indices into
    ``config["watermark.colors"]``. Its doubled string form is cached in
    ``config["watermark.color_pattern_str"]`` on first use and reused afterwards.
    NOTE(review): a later change to the pattern would not refresh that cache —
    confirm that is intended.

    If ``data["suffixes_colors"]`` is present, the colour of each row's last token
    is looked up in the last step of ``suffixes_colors`` and appended to
    ``data["suffixes_tokd_color"]`` (via ``tensor_cat``). The next colour is then
    the pattern entry just after the longest suffix of that colour history found
    in the pattern string. Otherwise the colour simply follows the pattern by
    position ``tokd_prev.size(1)``.

    Returns:
        Tensor of shape ``(tokd_prev.size(0), 1)`` with colour indices, on
        ``tokd_prev``'s device.
    """
    wm_colors = data["config"]["watermark.colors"]
    wm_pattern = data["config"]["watermark.color_pattern"]
    # Pattern entries index wm_colors, so they must lie in [0, len(wm_colors)).
    assert min(wm_pattern) >= 0 and max(wm_pattern) < len(wm_colors), \
        "watermark.color_pattern entries must be valid indices into watermark.colors"
    if "watermark.color_pattern_str" not in data["config"]:
        # Doubled so suffix matches that wrap around the pattern's end are found.
        wm_pattern_str = ''.join([str(i) for i in wm_pattern] * 2)
        data["config"]["watermark.color_pattern_str"] = wm_pattern_str
    else:
        wm_pattern_str = data["config"]["watermark.color_pattern_str"]
    if "suffixes_colors" in data:
        colors = data["suffixes_colors"][:, -1]
        tokd_col = colors[torch.arange(tokd_prev.size(0)), tokd_prev[:,-1]]
        tensor_cat(data, 'suffixes_tokd_color', tokd_col)
        prev_col = data['suffixes_tokd_color']
        next_col = [
            wm_pattern[_find_sublist(c, wm_pattern_str) % len(wm_pattern)]
            for c in prev_col]
        next_col = torch.tensor(next_col, device=tokd_prev.device)[:, None]
    else:
        pos = tokd_prev.size(1)
        col_idx = wm_pattern[pos % len(wm_pattern)]
        next_col = torch.full(
            (tokd_prev.size(0), 1), col_idx,
            device=tokd_prev.device)
    return next_col

def sweep_edit_window_size(data, ws=(2, 3, 4, 5, 6, 7),
                           t1_err_max=.1, return_edited_datas=False):
    """Compare edit-detection thresholds and errors across edit-detection window sizes.

    The function generates one watermarked batch and one randomly edited copy
    (``step_edit_tokd_randomize_contiguous_multiple``). For each window ``w`` in
    ``ws`` it re-scores both and asks ``lsearch_edit_detection_threshold`` for
    ``(threshold, t1, t2)`` with ``t1_err_max``.

    Each window is scored on fresh config copies. Previously the next iteration
    overwrote ``edit_detection.window`` in the config of the dict already
    collected for return, so returned datas reported the wrong window.

    Returns:
        DataFrame with one row per window; with ``return_edited_datas=True`` also
        the list of edited pipeline dicts, one per window.
    """
    def _rescore(d, window):
        # Copy the config so dicts already collected for return keep their settings.
        cfg = {**d['config'], 'detection.window': 1,  # doesn't matter
               'edit_detection.window': window}
        return steps_reduce([step_score_watermark, step_score_edit], {**d, 'config': cfg})

    data_wtmkd = steps_reduce(
        [step_random_reset, step_generate_watermarked], data)
    data_edited = steps_reduce(
        [step_edit_tokd_randomize_contiguous_multiple], data_wtmkd)
    df, result_datas = [], []
    for w in ws:
        data_wtmkd = _rescore(data_wtmkd, w)
        data_edited = _rescore(data_edited, w)
        th, t1, t2 = lsearch_edit_detection_threshold(
            data_wtmkd, data_edited, t1_err_max=t1_err_max)
        df.append({
            "detection.window": data_edited['config']['detection.window'],
            "edit_detection.window": data_edited['config']['edit_detection.window'],
            "edit_detection.threshold": th,
            "edit_detection.t1": t1,
            "edit_detection.t2": t2,
        })
        if return_edited_datas: result_datas.append(data_edited)
    df = pd.DataFrame(df)
    if not return_edited_datas: return df
    return df, result_datas

def step_lsearch_edit_detection_threshold_given_t1(data):
    """Linearly scan thresholds and keep the largest one meeting the target type-I error.

    Thresholds ``-EPS, 0, EPS, ...`` up to (excluding) ``1 + 2*EPS`` are tried in
    increasing order, with ``EPS = 0.25 / config["edit_detection.window"]``. Each
    is evaluated via ``step_eval_edit_detection_with_tolerance``. The last one
    whose batch-mean t1 is ``<= config["edit_detection.target_t1"]`` is kept, and
    ``data`` is re-evaluated at it.

    Raises:
        ValueError: if no scanned threshold meets the target. Previously the
            threshold became None and evaluation failed later with a TypeError.
    """
    best_th = None
    t1_target = data["config"]["edit_detection.target_t1"]
    EPS = .25 / data['config']['edit_detection.window']
    for th in np.arange(-EPS, 1+EPS*2, EPS):
        data["config"]["edit_detection.threshold"] = th
        data = step_eval_edit_detection_with_tolerance(data)
        t1, _t2, _, _, _ = data["edit_detection.eval"][-1].mean(dim=0)
        if t1 <= t1_target: best_th = th
    if best_th is None:
        raise ValueError(
            f"no edit_detection.threshold in [{-EPS}, {1 + EPS}] reaches "
            f"edit_detection.target_t1={t1_target}")
    data["config"]["edit_detection.threshold"] = best_th
    return step_eval_edit_detection_with_tolerance(data)
def sweep_edit_window_size_for_local_edit(
        data, ws=(2, 3, 4, 5, 6, 7, 8, 9), return_edited_datas=False):
    """Edit-detection metrics across window sizes for one fixed set of local edits.

    The function generates one watermarked batch and applies
    ``step_edit_tokd_randomize_contiguous_multiple`` once. For each window ``w``
    in ``ws`` it re-scores, binary-searches the threshold against
    ``config["edit_detection.target_t1"]``, evaluates, and records batch-mean
    t1/t2/precision/recall/F1.

    Each window is scored on a fresh config copy. Previously the next iteration
    overwrote ``edit_detection.window`` in the config of the dict already
    collected for return.

    Returns:
        DataFrame with one row per window; with ``return_edited_datas=True`` also
        the list of edited pipeline dicts, one per window.
    """
    data = steps_reduce(
        [step_random_reset, step_generate_watermarked], data)
    data_edited = steps_reduce(
        [step_edit_tokd_randomize_contiguous_multiple], data)
    df, result_datas = [], []
    for w in ws:
        # Copy the config so dicts already collected for return keep their settings.
        cfg = {**data_edited['config'], 'detection.window': 1,  # doesn't matter
               'edit_detection.window': w}
        data_edited = steps_reduce([
            step_score_watermark, step_score_edit,
            step_bsearch_edit_detection_threshold_given_t1,
            step_eval_edit_detection_with_tolerance], {**data_edited, 'config': cfg})
        t1, t2, prec, recall, f1 = data_edited["edit_detection.eval"][-1].mean(dim=0).numpy()
        df.append({
            "detection.window": data_edited['config']['detection.window'],
            "edit_detection.window": data_edited['config']['edit_detection.window'],
            "edit_detection.threshold": data_edited["config"]["edit_detection.threshold"],
            "edit_detection.t1": t1.item(),
            "edit_detection.t2": t2.item(),
            "edit_detection.prec": prec.item(),
            "edit_detection.recall": recall.item(),
            "edit_detection.f1": f1.item(),
        })
        if return_edited_datas: result_datas.append(data_edited)
    df = pd.DataFrame(df)
    if not return_edited_datas: return df
    return df, result_datas

#==== 14 ====#
def sweep_contiguous_edits_multiple(
        data, edit_types=("INSERT", "REPLACE", "REMOVE",),
        config_key="edits.count", config_vals=(1,2,3,4,5,6,7),
        first_batch_idx=0, batch_count=8, batch_size=32, return_edited_datas=False):
    """Sweep one edit-config parameter for multiple separated contiguous edits.

    The procedure covers batches ``first_batch_idx .. first_batch_idx+batch_count-1``
    of ``batch_size`` prefix lines. For each batch it generates and scores
    watermarked text and computes perplexity. Then, for every
    (``config_key`` value, edit type) pair, it applies
    ``step_edit_tokd_randomize_contiguous_multiple``, re-scores, and picks the
    edit threshold by linear scan against ``config["edit_detection.target_t1"]``.
    It records metrics alongside the un-edited batch's watermark score,
    ``eval.ppl`` and ``watermark.delta``.

    Returns:
        Metrics averaged over batches, indexed by ``(edits.type, config_key)``.
        With ``return_edited_datas=True`` it also returns the list of pipeline
        dicts (each un-edited batch followed by its edited variants). The full
        per-batch table is displayed when ``config["log"]`` is truthy.
    """
    df, result_datas = [], []
    data = steps_reduce([step_model_load, step_random_reset], data)
    for batch_idx in range(first_batch_idx, first_batch_idx+batch_count):
        data["config"]["data.lines"] = (batch_size*batch_idx, batch_size*(batch_idx+1))
        data_batch = steps_reduce([
            step_prefixes_load, step_tokenize, step_generate_watermarked,
            step_score_watermark, step_score_edit, step_ppl_compute], data)
        # Keep batch dicts only when they will be returned; otherwise every batch
        # would stay alive in memory until the whole sweep finishes.
        if return_edited_datas: result_datas.append(data_batch)

        tqdm_indicator = tqdm(sorted(itertools.product(config_vals, edit_types)))
        for config_val, edit_type in tqdm_indicator:
            tqdm_indicator.set_description(f"batch_idx={batch_idx}, {edit_type} {config_val}")
            data_edited = copy.deepcopy(data_batch)
            data_edited["suffixes_tokd"] = data_edited["suffixes_tokd"].clone()
            data_edited["config"]["edits.type"] = edit_type
            data_edited["config"][config_key] = config_val
            data_edited = steps_reduce([
                step_edit_tokd_randomize_contiguous_multiple,
                step_score_watermark, step_score_edit,
                step_lsearch_edit_detection_threshold_given_t1,
                step_eval_edit_detection_with_tolerance],
                data_edited)
            t1, t2, prec, recall, f1 = data_edited["edit_detection.eval"][-1].mean(dim=0).numpy()
            df.append({
                "batch_idx": batch_idx,
                "edits.type": edit_type,
                config_key: config_val,
                "wm_detection.scores": data_edited["wm_detection.scores"][:,-1].mean().item(),
                "edit_detection.scores": data_edited["edit_detection.scores"][:,-1].mean().item(),
                "edit_detection.threshold": data_edited["config"]["edit_detection.threshold"],
                "edit_detection.t1": t1.item(),
                "edit_detection.t2": t2.item(),
                "edit_detection.prec": prec.item(),
                "edit_detection.recall": recall.item(),
                "edit_detection.f1": f1.item(),
                "base.wm_detection.scores": data_batch["wm_detection.scores"][:,-1].mean().item(),
                "base.ppl": data_batch["eval.ppl"],
                "watermark.delta": data_batch["config"]["watermark.delta"],
            })
            if return_edited_datas: result_datas.append(data_edited)
            data_edited = None
            cache_clear(data)
            tqdm_indicator.set_description(f"batch_idx={batch_idx} DONE")
    df = pd.DataFrame(df)
    df = df.sort_values(by=['edits.type', config_key])
    if data["config"].get("log", False): display(df)
    df = df.groupby(['edits.type', config_key]).mean()
    df = df.drop(columns=['batch_idx'])
    if not return_edited_datas: return df
    return df, result_datas

def plot_contiguous_edit_sweep(df, color, label):
    """Plot a grid of metric-vs-parameter curves from a sweep DataFrame.

    ``df`` is indexed by ``(edits.type, <config_key>)``, as returned by
    ``sweep_contiguous_edits_multiple``. Rows of the subplot grid are Type II
    error, Type I error and F-score; columns are the edit types. Each call adds
    one dashed line per subplot with the given ``color`` and ``label``, so
    several sweeps can be overlaid.
    """
    edit_types = sorted(set(df.index.get_level_values(0)))
    config_key = df.index.names[1]
    config_vals = sorted(set(df.index.get_level_values(1)))
    metrics = [
        ("edit_detection.t2", "Type II Error", (0, 1)),
        ("edit_detection.t1", "Type I Error", (.0, 1)),
        ("edit_detection.f1", "F-Score", (.0, 1)),
    ]
    n_rows, n_cols = len(metrics), len(edit_types)
    for metric_idx, (metric, metric_name, metric_range) in enumerate(metrics):
        # `edit_type` rather than `type`, which would shadow the builtin.
        for type_idx, edit_type in enumerate(edit_types):
            plt.subplot(n_rows, n_cols, metric_idx * n_cols + type_idx + 1)
            if metric_idx == 0: plt.title(edit_type)
            if metric_idx == n_rows - 1: plt.xlabel(config_key)
            if type_idx == 0: plt.ylabel(metric_name)
            ys = df[df.index.get_level_values('edits.type') == edit_type][metric]
            plt.plot(config_vals, ys, marker='X', linestyle='--', linewidth=2,
                     alpha=.5, color=color, label=label)
            plt.ylim(*metric_range)
            plt.xlim(min(config_vals), max(config_vals))
            plt.grid(True)
    plt.tight_layout(pad=1.1)
